﻿#include <algorithm>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>
using namespace std;

/// Median of two ascending-sorted integer arrays via rank selection.
///
/// Both public methods expect their inputs to be sorted in ascending order;
/// this is a precondition and is not verified.
class Solution {
public:
    /// @brief Returns the kth smallest (1-based) element of the merged tails
    ///        a[sta..] and b[stb..].
    ///
    /// Every round compares the elements roughly kth/2 positions into each
    /// tail and drops the prefix that ends in the smaller one: those elements
    /// are guaranteed to rank below kth, so kth shrinks by the number dropped.
    /// The search therefore needs only a logarithmic number of rounds in kth.
    ///
    /// @param a    first ascending-sorted array
    /// @param sta  start offset into a (expected non-negative); an offset at
    ///             or past a.size() means a contributes no elements
    /// @param b    second ascending-sorted array
    /// @param stb  start offset into b, same convention as sta
    /// @param kth  1-based rank to select among the remaining elements
    /// @return the selected element
    /// @throws std::out_of_range if kth < 1 or kth exceeds the number of
    ///         remaining elements (previously an out-of-bounds read).
    int findKth(const std::vector<int>& a, int sta, const std::vector<int>& b, int stb, int kth) {
        const std::size_t sizeA = a.size();
        const std::size_t sizeB = b.size();

        // Work in unsigned indices from here on; clamping keeps the
        // "offset >= size means exhausted" convention and makes the
        // subtraction below safe.
        std::size_t ia = std::min(static_cast<std::size_t>(sta), sizeA);
        std::size_t ib = std::min(static_cast<std::size_t>(stb), sizeB);

        const std::size_t available = (sizeA - ia) + (sizeB - ib);
        if (kth < 1 || static_cast<std::size_t>(kth) > available) {
            throw std::out_of_range("findKth: kth is outside the range of remaining elements");
        }

        std::size_t remaining = static_cast<std::size_t>(kth);
        while (true) {
            // One side exhausted: the answer is a direct index into the other.
            if (ia >= sizeA) return b[ib + remaining - 1];
            if (ib >= sizeB) return a[ia + remaining - 1];
            if (remaining == 1) return std::min(a[ia], b[ib]);

            // Probe up to remaining/2 elements into each tail; a short tail
            // is probed at its last element instead.
            const std::size_t half = remaining / 2;
            const std::size_t takeA = std::min(half, sizeA - ia);
            const std::size_t takeB = std::min(half, sizeB - ib);
            const int probeA = a[ia + takeA - 1];
            const int probeB = b[ib + takeB - 1];

            // The prefix ending in the smaller probe cannot contain the
            // target rank (ties discard from b), so skip over it.
            if (probeA < probeB) {
                ia += takeA;
                remaining -= takeA;
            } else {
                ib += takeB;
                remaining -= takeB;
            }
        }
    }

    /// @brief Returns the median of the combined contents of a and b.
    ///
    /// For an odd total this is the middle element; for an even total it is
    /// the mean of the two middle elements.
    ///
    /// @param a first ascending-sorted array (not modified)
    /// @param b second ascending-sorted array (not modified)
    /// @return the median as a double
    /// @throws std::invalid_argument if both arrays are empty, since no
    ///         median exists (previously an out-of-bounds read).
    double findMedianSortedArrays(std::vector<int>& a, std::vector<int>& b) {
        const std::size_t total = a.size() + b.size();
        if (total == 0) {
            throw std::invalid_argument("findMedianSortedArrays: both input arrays are empty");
        }

        // 1-based ranks of the lower and upper middle; equal when total is odd.
        const int lowerRank = static_cast<int>((total + 1) / 2);
        const int upperRank = static_cast<int>((total + 2) / 2);

        const int lower = findKth(a, 0, b, 0, lowerRank);
        // Skip the second search when both ranks coincide.
        const int upper = (upperRank == lowerRank) ? lower : findKth(a, 0, b, 0, upperRank);

        // Add in double: lower + upper can overflow int (e.g. both INT_MAX).
        return (static_cast<double>(lower) + static_cast<double>(upper)) / 2.0;
    }
};

int main() {
    vector<int> nums1 = { 1, 3 };
    vector<int> nums2 = { 2 };
    Solution solution;
    double result = solution.findMedianSortedArrays(nums1, nums2);
    cout << "两个有序数组的中位数为: " << result << endl;
    return 0;
}